{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a07356b3-744a-4319-9fec-cd62f37fa865",
   "metadata": {},
   "source": [
    "# Data Preprocessing with DataFrame"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7d72fa78",
   "metadata": {},
   "source": [
The following">
    ">The following code is for demonstration only. It is **NOT for production** due to system security concerns; please **DO NOT** use it directly in production."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e89e9dcf",
   "metadata": {},
   "source": [
    "It is recommended to use [jupyter](https://jupyter.org/) to run this tutorial."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8ff05a38-2211-4240-a9db-0d79c813ab99",
   "metadata": {},
   "source": [
    "Secretflow provides a variety of preprocessing tools to process data."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "25ec1569-f9f7-4f27-90b8-a6c7feab28e2",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Preparation\n",
    "\n",
    "Initialize secretflow and create two parties alice and bob."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "83e1596d-8ca1-40ae-9681-7254c563ff7e",
   "metadata": {},
   "source": [
    "> 💡 Before using preprocessing, you may need to get to know secretflow's [DataFrame](../user_guide/preprocessing/DataFrame.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ad74320-2c3a-4c86-aea4-6688d96d2230",
   "metadata": {},
   "outputs": [],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check the version of your SecretFlow\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# In case you have a running secretflow runtime already.\n",
    "sf.shutdown()\n",
    "\n",
    "sf.init(['alice', 'bob'], address='local')\n",
    "alice = sf.PYU('alice')\n",
    "bob = sf.PYU('bob')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "94c83c7b-417a-4772-9de1-2efc589cd89f",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Data Preparation"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "86168ad6-2fe0-4410-b59c-fd65cbe8ea9b",
   "metadata": {},
   "source": [
    "Here we use [iris](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html) as example data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9d7d70b8-2d12-40c0-891e-d42cbd567cab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>virginica</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>150 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \\\n",
       "0                  5.1               3.5                1.4               0.2   \n",
       "1                  4.9               NaN                1.4               0.2   \n",
       "2                  4.7               3.2                1.3               0.2   \n",
       "3                  4.6               3.1                1.5               0.2   \n",
       "4                  5.0               3.6                1.4               0.2   \n",
       "..                 ...               ...                ...               ...   \n",
       "145                6.7               3.0                5.2               2.3   \n",
       "146                6.3               2.5                5.0               1.9   \n",
       "147                6.5               3.0                5.2               2.0   \n",
       "148                6.2               3.4                5.4               2.3   \n",
       "149                5.9               3.0                5.1               1.8   \n",
       "\n",
       "        target  \n",
       "0       setosa  \n",
       "1       setosa  \n",
       "2       setosa  \n",
       "3       setosa  \n",
       "4       setosa  \n",
       "..         ...  \n",
       "145  virginica  \n",
       "146  virginica  \n",
       "147  virginica  \n",
       "148  virginica  \n",
       "149  virginica  \n",
       "\n",
       "[150 rows x 5 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris\n",
    "\n",
    "iris = load_iris(as_frame=True)\n",
    "data = pd.concat([iris.data, iris.target], axis=1)\n",
    "\n",
    "# In order to facilitate the subsequent display,\n",
    "# here we first set some data to None.\n",
    "data.iloc[1, 1] = None\n",
    "data.iloc[100, 1] = None\n",
    "\n",
    "# Restore target to its original name.\n",
    "data['target'] = data['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica' })\n",
    "data"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e109ddee-3998-4c1a-8bb6-797cfe9a31f0",
   "metadata": {},
   "source": [
    "Create a vertical partitioned DataFrame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "90206eca-f690-412d-9c09-e70e6830f6ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "from secretflow.data.vertical import read_csv as v_read_csv\n",
    "\n",
    "# Vertical partitioning.\n",
    "v_alice, v_bob = data.iloc[:, :2], data.iloc[:, 2:]\n",
    "\n",
    "# Save to temprary files.\n",
    "_, alice_path = tempfile.mkstemp()\n",
    "_, bob_path = tempfile.mkstemp()\n",
    "v_alice.to_csv(alice_path, index=False)\n",
    "v_bob.to_csv(bob_path, index=False)\n",
    "\n",
    "\n",
    "df = v_read_csv({alice: alice_path, bob: bob_path})"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "338872d7",
   "metadata": {},
   "source": [
    "You can also create a horizontal partitioned DataFrame, which works the same with vertical partitioning for subsequent steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2fd99e07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from secretflow.data.horizontal import read_csv as h_read_csv\n",
    "# from secretflow.security.aggregation import PlainAggregator\n",
    "# from secretflow.security.compare import PlainComparator\n",
    "\n",
    "# # Horizontal partitioning.\n",
    "# h_alice, h_bob = data.iloc[:70, :], data.iloc[70:, :]\n",
    "\n",
    "# # Save to temorary files.\n",
    "# _, h_alice_path = tempfile.mkstemp()\n",
    "# _, h_bob_path = tempfile.mkstemp()\n",
    "# h_alice.to_csv(h_alice_path, index=False)\n",
    "# h_bob.to_csv(h_bob_path, index=False)\n",
    "\n",
    "# df = h_read_csv(\n",
    "#     {alice: h_alice_path, bob: h_bob_path},\n",
    "#     aggregator=PlainAggregator(alice),\n",
    "#     comparator=PlainComparator(alice),\n",
    "# )"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b346a56c-6a5b-4b59-8c36-99344febdb51",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Preprocessing\n",
    "\n",
    "Secretflow provides missing value filling, standardization, categorical features encoding, discretization .etc, which are similar to sklearn's preprocessing."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c22e1dc7-508c-487b-8434-fb46f2cd0bcf",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Missing value filling\n",
    "\n",
    "DataFrame provides the fillna method, which can fill in missing values in the same way as pandas."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f53eb206-7ba5-44ab-9b70-713af21ae28c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "148"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Before filling, the sepal width (cm) is missing in two positions.\n",
    "df.count()['sepal width (cm)']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b8c94841-73b3-4eb8-913d-af2f29c15ee2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "150"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fill sepal width (cm) with 10.\n",
    "df.fillna(value={'sepal width (cm)': 10}).count()['sepal width (cm)']"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fd909a7a-0cde-4cdc-b8a6-bbb09a909773",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Standardization\n",
    "#### Scaling features to a range\n",
    "\n",
    "Secretflow provides `MinMaxScaler` for scaling features to lie between a given minimum and maximum value. The input and output of MinMaxScaler are both `DataFrame`.\n",
    "\n",
    "Here is an exmaple to scale `sepal length (cm)` to the [0, 1] range."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "51043015-d6e8-4461-b205-eeb3a966fe83",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Min:  sepal length (cm)    0.0\n",
      "dtype: float64\n",
      "Max:  sepal length (cm)    1.0\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "from secretflow.preprocessing import MinMaxScaler\n",
    "\n",
    "scaler = MinMaxScaler()\n",
    "\n",
    "scaled_sepal_len = scaler.fit_transform(df['sepal length (cm)'])\n",
    "\n",
    "print('Min: ', scaled_sepal_len.min())\n",
    "print('Max: ', scaled_sepal_len.max())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5a006c19",
   "metadata": {},
   "source": [
    "#### Variance scaling\n",
    "Secretflow provides `StandardScaler` for variance scaling. The input and output of StandardScaler are both DataFrames.\n",
    "\n",
    "Here is an exmaple to scale `sepal length (cm)` to unit variance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cabf7b70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Min:  sepal length (cm)   -1.870024\n",
      "dtype: float64\n",
      "Max:  sepal length (cm)    2.492019\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "from secretflow.preprocessing import StandardScaler\n",
    "\n",
    "scaler = StandardScaler()\n",
    "\n",
    "scaled_sepal_len = scaler.fit_transform(df['sepal length (cm)'])\n",
    "\n",
    "print('Min: ', scaled_sepal_len.min())\n",
    "print('Max: ', scaled_sepal_len.max())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4d48a464-10df-4a86-82a3-e51fb7d10579",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Encoding categorical features\n",
    "\n",
    "#### OneHot encoding\n",
    "Secretflow provides `OneHotEncoder` for OneHot encoding. The input and output of OneHotEncoder are `DataFrame`.\n",
    "\n",
    "Here is an example to encode target with onehot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fe2023a4-6e6b-4ed8-892a-78df980e7ded",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Columns:  Index(['target_setosa', 'target_versicolor', 'target_virginica'], dtype='object')\n",
      "Min: \n",
      " target_setosa        0.0\n",
      "target_versicolor    0.0\n",
      "target_virginica     0.0\n",
      "dtype: float64\n",
      "Max: \n",
      " target_setosa        1.0\n",
      "target_versicolor    1.0\n",
      "target_virginica     1.0\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "from secretflow.preprocessing import OneHotEncoder\n",
    "\n",
    "onehot_encoder = OneHotEncoder()\n",
    "onehot_target = onehot_encoder.fit_transform(df['target'])\n",
    "\n",
    "print('Columns: ', onehot_target.columns)\n",
    "print('Min: \\n', onehot_target.min())\n",
    "print('Max: \\n', onehot_target.max())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3d2434aa-9f1f-40ed-b513-78c179e24635",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Label encoding\n",
    "\n",
    "secretflow provides `LabelEncoder` for encoding target labels with value between 0 and n_classes-1. The input and output of LabelEncoder are `DataFrame`.\n",
    "\n",
    "Here is an example to encode target to [0, n_classes-1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3c55619d-c12a-444d-b3a0-ab05d71dacc4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Columns:  Index(['target'], dtype='object')\n",
      "Min: \n",
      " target    0\n",
      "dtype: int64\n",
      "Max: \n",
      " target    2\n",
      "dtype: int64\n"
     ]
    }
   ],
   "source": [
    "from secretflow.preprocessing import LabelEncoder\n",
    "\n",
    "label_encoder = LabelEncoder()\n",
    "encoded_label = label_encoder.fit_transform(df['target'])\n",
    "\n",
    "print('Columns: ', encoded_label.columns)\n",
    "print('Min: \\n', encoded_label.min())\n",
    "print('Max: \\n', encoded_label.max())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "00d70482",
   "metadata": {},
   "source": [
    "### Discretization\n",
    "\n",
    "SecretFlow provides `KBinsDiscretizer` for partitioning continuous features into discrete values. The input and output of KBinsDiscretizer are both `DataFrame`.\n",
    "\n",
    "Here is an example to partition `petal length (cm)` to 5 bins."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a3408341",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Min: \n",
      " petal length (cm)    0.0\n",
      "dtype: float64\n",
      "Max: \n",
      " petal length (cm)    4.0\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "from secretflow.preprocessing import KBinsDiscretizer\n",
    "\n",
    "estimator = KBinsDiscretizer(n_bins=5)\n",
    "binned_petal_len = estimator.fit_transform(df['petal length (cm)'])\n",
    "\n",
    "print('Min: \\n', binned_petal_len.min())\n",
    "print('Max: \\n', binned_petal_len.max())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "343704ae",
   "metadata": {},
   "source": [
    "#### WOE encoding\n",
    "\n",
    "secretflow provides `VertWoeBinning` to bin the features into buckets by quantile or chimerge method, and calculate the woe value and iv value in each bucket. And `VertWOESubstitution` can substitute the features with the woe value.\n",
    "\n",
    "Here is an example to encode features to woe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "16e22ed1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "orig ds in alice:\n",
      "             x1        x2        x3\n",
      "0    -0.514226  0.730010 -0.730391\n",
      "1    -0.725537  0.482244 -0.823223\n",
      "2     0.608353 -0.071102 -0.775098\n",
      "3    -0.686642  0.160470  0.914477\n",
      "4    -0.198111  0.212909  0.950474\n",
      "...        ...       ...       ...\n",
      "9995 -0.367246 -0.296454  0.558596\n",
      "9996  0.010913  0.629268 -0.384093\n",
      "9997 -0.238097  0.904069 -0.344859\n",
      "9998  0.453686 -0.375173  0.899238\n",
      "9999 -0.776015 -0.772112  0.012110\n",
      "\n",
      "[10000 rows x 3 columns]\n",
      "orig ds in bob:\n",
      "            x18       x19       x20  y\n",
      "0     0.810261  0.048303  0.937679  1\n",
      "1     0.312728  0.526637  0.589773  1\n",
      "2     0.039087 -0.753417  0.516735  0\n",
      "3    -0.855979  0.250944  0.979465  1\n",
      "4    -0.238805  0.243109 -0.121446  1\n",
      "...        ...       ...       ... ..\n",
      "9995 -0.847253  0.069960  0.786748  1\n",
      "9996 -0.502486 -0.076290 -0.604832  1\n",
      "9997 -0.424209  0.434947  0.998955  1\n",
      "9998  0.914291 -0.473056  0.616257  1\n",
      "9999 -0.602927 -0.021368  0.885519  0\n",
      "\n",
      "[10000 rows x 4 columns]\n",
      "woe_rules for alice:\n",
      " {'variables': [{'name': 'x1', 'type': 'numeric', 'split_points': [-0.6048731088638305, -0.2093792676925656, 0.1864844083786014, 0.59245548248291], 'woes': [0.13818949789069251, 0.1043626580338657, 0.012473718947119546, -0.08312553911263658, -0.16055365315128886], 'ivs': [0.003719895277030594, 0.0021358508878795324, 3.104792224414912e-05, 0.001402356358949689, 0.005298781871642081], 'total_counts': [2000, 2000, 2000, 2000, 2000], 'else_woe': -0.7620562438001163, 'else_iv': 6.385923806287768e-05, 'else_counts': 0}, {'name': 'x2', 'type': 'numeric', 'split_points': [-0.6180597543716427, -0.21352910995483343, 0.18739376068115243, 0.5941788196563724], 'woes': [-0.5795513521445242, -0.17800092651085536, 0.02175062133493428, 0.32061945260518093, 0.5508555713857505], 'ivs': [0.07282166470764677, 0.006530977061248972, 9.424153452104671e-05, 0.019271267842100412, 0.05393057944296773], 'total_counts': [2000, 2000, 2000, 2000, 2000], 'else_woe': -0.7620562438001163, 'else_iv': 6.385923806287768e-05, 'else_counts': 0}, {'name': 'x3', 'type': 'numeric', 'split_points': [-0.5902724504470824, -0.19980529546737677, 0.2072824716567998, 0.6102998018264773], 'woes': [-0.5371125119817587, -0.25762552591997334, -0.022037294110497735, 0.3445721198562295, 0.6304998785437507], 'ivs': [0.062290057806557865, 0.013846189451494859, 9.751520288052276e-05, 0.022140413942886045, 0.06928418076454097], 'total_counts': [2000, 2000, 2000, 2000, 2000], 'else_woe': -0.7620562438001163, 'else_iv': 6.385923806287768e-05, 'else_counts': 0}]}\n",
      "woe_rules for bob:\n",
      " {'variables': [{'name': 'x18', 'type': 'numeric', 'split_points': [-0.595701837539673, -0.18646149635314926, 0.20281808376312258, 0.5969645977020259], 'woes': [0.7644870924575128, 0.3796894156855692, 0.09717493242210018, -0.3856750302449858, -0.6258460389655672], 'ivs': [0.0984553650514661, 0.026672043182215357, 0.0018543743703697353, 0.031572379289703925, 0.08527363286990977], 'total_counts': [2000, 2000, 2000, 2000, 2000], 'else_woe': -0.7620562438001163, 'else_iv': 6.385923806287768e-05, 'else_counts': 0}, {'name': 'x19', 'type': 'numeric', 'split_points': [-0.5988080263137814, -0.2046342611312865, 0.1958462238311768, 0.6044608354568479], 'woes': [-0.24268812281101115, -0.18886157950622262, 0.061543825157264156, 0.15773711862524092, 0.24528753075504478], 'ivs': [0.012260322787780317, 0.0073647296376464335, 0.0007489127774465146, 0.0048277504221347, 0.011464529974064627], 'total_counts': [2000, 2000, 2000, 2000, 2000], 'else_woe': -0.7620562438001163, 'else_iv': 6.385923806287768e-05, 'else_counts': 0}, {'name': 'x20', 'type': 'numeric', 'split_points': [-0.6013513207435608, -0.2053116083145139, 0.19144065380096467, 0.5987063169479374], 'woes': [1.1083043875152403, 0.5598579367731444, 0.15773711862524092, -0.4618210945247346, -0.9083164208649596], 'ivs': [0.1887116758575296, 0.055586120695474514, 0.0048277504221347, 0.04568212645465593, 0.18279475271010598], 'total_counts': [2000, 2000, 2000, 2000, 2000], 'else_woe': -0.7620562438001163, 'else_iv': 6.385923806287768e-05, 'else_counts': 0}]}\n",
      "substituted ds in alice:\n",
      "             x1        x2        x3\n",
      "0     0.104363  0.550856 -0.537113\n",
      "1     0.138189  0.320619 -0.537113\n",
      "2    -0.160554  0.021751 -0.537113\n",
      "3     0.138189  0.021751  0.630500\n",
      "4     0.012474  0.320619  0.630500\n",
      "...        ...       ...       ...\n",
      "9995  0.104363 -0.178001  0.344572\n",
      "9996  0.012474  0.550856 -0.257626\n",
      "9997  0.104363  0.550856 -0.257626\n",
      "9998 -0.083126 -0.178001  0.630500\n",
      "9999  0.138189 -0.579551 -0.022037\n",
      "\n",
      "[10000 rows x 3 columns]\n",
      "substituted ds in bob:\n",
      "            x18       x19       x20  y\n",
      "0    -0.625846  0.061544 -0.908316  1\n",
      "1    -0.385675  0.157737 -0.461821  1\n",
      "2     0.097175 -0.242688 -0.461821  0\n",
      "3     0.764487  0.157737 -0.908316  1\n",
      "4     0.379689  0.157737  0.157737  1\n",
      "...        ...       ...       ... ..\n",
      "9995  0.764487  0.061544 -0.908316  1\n",
      "9996  0.379689  0.061544  1.108304  1\n",
      "9997  0.379689  0.157737 -0.908316  1\n",
      "9998 -0.625846 -0.188862 -0.908316  1\n",
      "9999  0.764487  0.061544 -0.908316  0\n",
      "\n",
      "[10000 rows x 4 columns]\n"
     ]
    }
   ],
   "source": [
    "# woe binning use SPU or HEU device to protect label\n",
    "spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))\n",
    "\n",
    "# Only support binary classification label dataset for now.\n",
    "# use linear dataset as example\n",
    "from secretflow.utils.simulation.datasets import load_linear\n",
    "vdf = load_linear(parts={alice: (1, 4), bob: (18, 22)})\n",
    "print(f\"orig ds in alice:\\n {sf.reveal(vdf.partitions[alice].data)}\")\n",
    "print(f\"orig ds in bob:\\n {sf.reveal(vdf.partitions[bob].data)}\")\n",
    "\n",
    "from secretflow.preprocessing.binning.vert_woe_binning import VertWoeBinning\n",
    "\n",
    "binning = VertWoeBinning(spu)\n",
    "woe_rules = binning.binning(\n",
    "    vdf,\n",
    "    binning_method=\"quantile\",\n",
    "    bin_num=5,\n",
    "    bin_names={alice: [\"x1\", \"x2\", \"x3\"], bob: [\"x18\", \"x19\", \"x20\"]},\n",
    "    label_name=\"y\",\n",
    ")\n",
    "\n",
    "print(f\"woe_rules for alice:\\n {sf.reveal(woe_rules[alice])}\")\n",
    "print(f\"woe_rules for bob:\\n {sf.reveal(woe_rules[bob])}\")\n",
    "\n",
    "from secretflow.preprocessing.binning.vert_woe_substitution import VertWOESubstitution\n",
    "\n",
    "woe_sub = VertWOESubstitution()\n",
    "sub_data = woe_sub.substitution(vdf, woe_rules)\n",
    "\n",
    "print(f\"substituted ds in alice:\\n {sf.reveal(sub_data.partitions[alice].data)}\")\n",
    "print(f\"substituted ds in bob:\\n {sf.reveal(sub_data.partitions[bob].data)}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5dd4ddd5",
   "metadata": {},
   "source": [
    "## Ending"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "831cc32e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up temporary files\n",
    "\n",
    "import os\n",
    "\n",
    "try:\n",
    "    os.remove(alice_path)\n",
    "    os.remove(bob_path)\n",
    "except OSError:\n",
    "    pass"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('py3.8')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "66d1547304beaba725027c44e57cc46fc747862fe9496520910412a3187eb35f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
